{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3362a434",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset chn_senti_corp (/Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(9600,\n",
       " ('选择珠江花园的原因就是方便，有电动扶梯直接到达海边，周围餐馆、食廊、商场、超市、摊位一应俱全。酒店装修一般，但还算整洁。 泳池在大堂的屋顶，因此很小，不过女儿倒是喜欢。 包的早餐是西式的，还算丰富。 服务吗，一般',\n",
       "  1))"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import torch\n",
    "from datasets import load_dataset\n",
    "\n",
    "\n",
    "#定义数据集\n",
    "class Dataset(torch.utils.data.Dataset):\n",
    "    def __init__(self, split):\n",
    "        self.dataset = load_dataset(path='seamew/ChnSentiCorp', split=split)\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.dataset)\n",
    "\n",
    "    def __getitem__(self, i):\n",
    "        text = self.dataset[i]['text']\n",
    "        label = self.dataset[i]['label']\n",
    "\n",
    "        return text, label\n",
    "\n",
    "\n",
    "dataset = Dataset('train')\n",
    "\n",
    "len(dataset), dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e70a58c6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PreTrainedTokenizer(name_or_path='bert-base-chinese', vocab_size=21128, model_max_len=512, is_fast=False, padding_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'})"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertTokenizer\n",
    "\n",
    "#加载字典和分词工具\n",
    "token = BertTokenizer.from_pretrained('bert-base-chinese')\n",
    "\n",
    "token"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e59695a4",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "600\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "(torch.Size([16, 500]),\n",
       " torch.Size([16, 500]),\n",
       " torch.Size([16, 500]),\n",
       " tensor([1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1]))"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def collate_fn(data):\n",
    "    sents = [i[0] for i in data]\n",
    "    labels = [i[1] for i in data]\n",
    "\n",
    "    #编码\n",
    "    data = token.batch_encode_plus(batch_text_or_text_pairs=sents,\n",
    "                                   truncation=True,\n",
    "                                   padding='max_length',\n",
    "                                   max_length=500,\n",
    "                                   return_tensors='pt',\n",
    "                                   return_length=True)\n",
    "\n",
    "    #input_ids:编码之后的数字\n",
    "    #attention_mask:是补零的位置是0,其他位置是1\n",
    "    input_ids = data['input_ids']\n",
    "    attention_mask = data['attention_mask']\n",
    "    token_type_ids = data['token_type_ids']\n",
    "    labels = torch.LongTensor(labels)\n",
    "\n",
    "    #print(data['length'], data['length'].max())\n",
    "\n",
    "    return input_ids, attention_mask, token_type_ids, labels\n",
    "\n",
    "\n",
    "#数据加载器\n",
    "loader = torch.utils.data.DataLoader(dataset=dataset,\n",
    "                                     batch_size=16,\n",
    "                                     collate_fn=collate_fn,\n",
    "                                     shuffle=True,\n",
    "                                     drop_last=True)\n",
    "\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    break\n",
    "\n",
    "print(len(loader))\n",
    "input_ids.shape, attention_mask.shape, token_type_ids.shape, labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "f620d0e3",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertModel: ['cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 500, 768])"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import BertModel\n",
    "\n",
    "#加载预训练模型\n",
    "pretrained = BertModel.from_pretrained('bert-base-chinese')\n",
    "\n",
    "#不训练,不需要计算梯度\n",
    "for param in pretrained.parameters():\n",
    "    param.requires_grad_(False)\n",
    "\n",
    "#模型试算\n",
    "out = pretrained(input_ids=input_ids,\n",
    "           attention_mask=attention_mask,\n",
    "           token_type_ids=token_type_ids)\n",
    "\n",
    "out.last_hidden_state.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5d3d02a2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([16, 2])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#定义下游任务模型\n",
    "class Model(torch.nn.Module):\n",
    "    def __init__(self):\n",
    "        super().__init__()\n",
    "        self.fc = torch.nn.Linear(768, 2)\n",
    "\n",
    "    def forward(self, input_ids, attention_mask, token_type_ids):\n",
    "        with torch.no_grad():\n",
    "            out = pretrained(input_ids=input_ids,\n",
    "                       attention_mask=attention_mask,\n",
    "                       token_type_ids=token_type_ids)\n",
    "\n",
    "        out = self.fc(out.last_hidden_state[:, 0])\n",
    "\n",
    "        out = out.softmax(dim=1)\n",
    "\n",
    "        return out\n",
    "\n",
    "\n",
    "model = Model()\n",
    "\n",
    "model(input_ids=input_ids,\n",
    "      attention_mask=attention_mask,\n",
    "      token_type_ids=token_type_ids).shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "1bd44a7c",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.6782209873199463 0.625\n",
      "5 0.5840441584587097 0.75\n",
      "10 0.6991336941719055 0.375\n",
      "15 0.6459898948669434 0.625\n",
      "20 0.6601291298866272 0.5625\n",
      "25 0.6269813776016235 0.8125\n",
      "30 0.595515251159668 0.875\n",
      "35 0.5542500019073486 0.875\n",
      "40 0.5360901355743408 0.8125\n",
      "45 0.5258952379226685 0.9375\n",
      "50 0.5396168828010559 0.875\n",
      "55 0.5625669956207275 0.75\n",
      "60 0.503290057182312 0.875\n",
      "65 0.5144746899604797 0.75\n",
      "70 0.4966588318347931 0.875\n",
      "75 0.5100283622741699 0.8125\n",
      "80 0.5462980270385742 0.6875\n",
      "85 0.5015004873275757 0.9375\n",
      "90 0.4921759068965912 0.875\n",
      "95 0.5301690697669983 0.8125\n",
      "100 0.4322471618652344 0.9375\n",
      "105 0.4398854672908783 1.0\n",
      "110 0.6205551028251648 0.6875\n",
      "115 0.4555570185184479 0.9375\n",
      "120 0.43458104133605957 0.9375\n",
      "125 0.5856747031211853 0.8125\n",
      "130 0.45579853653907776 0.875\n",
      "135 0.49450209736824036 0.875\n",
      "140 0.4834059476852417 0.875\n",
      "145 0.41298091411590576 0.9375\n",
      "150 0.6239964365959167 0.6875\n",
      "155 0.42134153842926025 0.9375\n",
      "160 0.41760149598121643 1.0\n",
      "165 0.4275535047054291 0.9375\n",
      "170 0.5800575613975525 0.75\n",
      "175 0.44518887996673584 0.875\n",
      "180 0.42843857407569885 0.9375\n",
      "185 0.4298834800720215 0.9375\n",
      "190 0.46833470463752747 0.8125\n",
      "195 0.48308607935905457 0.8125\n",
      "200 0.5299521684646606 0.8125\n",
      "205 0.41276633739471436 0.9375\n",
      "210 0.4676920771598816 0.875\n",
      "215 0.4392228424549103 0.875\n",
      "220 0.4800220727920532 0.875\n",
      "225 0.5499932169914246 0.6875\n",
      "230 0.4250292479991913 0.875\n",
      "235 0.44226840138435364 0.9375\n",
      "240 0.527982771396637 0.75\n",
      "245 0.398252934217453 0.9375\n",
      "250 0.4963655173778534 0.875\n",
      "255 0.5123838186264038 0.8125\n",
      "260 0.4301964342594147 0.9375\n",
      "265 0.561985969543457 0.8125\n",
      "270 0.5266203880310059 0.75\n",
      "275 0.4799845516681671 0.8125\n",
      "280 0.3876492977142334 0.9375\n",
      "285 0.4564688503742218 0.8125\n",
      "290 0.4763532280921936 0.875\n",
      "295 0.538104772567749 0.75\n",
      "300 0.39595580101013184 0.9375\n"
     ]
    }
   ],
   "source": [
    "from transformers import AdamW\n",
    "\n",
    "#训练\n",
    "optimizer = AdamW(model.parameters(), lr=5e-4)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "model.train()\n",
    "for i, (input_ids, attention_mask, token_type_ids,\n",
    "        labels) in enumerate(loader):\n",
    "    out = model(input_ids=input_ids,\n",
    "                attention_mask=attention_mask,\n",
    "                token_type_ids=token_type_ids)\n",
    "\n",
    "    loss = criterion(out, labels)\n",
    "    loss.backward()\n",
    "    optimizer.step()\n",
    "    optimizer.zero_grad()\n",
    "\n",
    "    if i % 5 == 0:\n",
    "        out = out.argmax(dim=1)\n",
    "        accuracy = (out == labels).sum().item() / len(labels)\n",
    "\n",
    "        print(i, loss.item(), accuracy)\n",
    "\n",
    "    if i == 300:\n",
    "        break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "275dd1b6",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Using custom data configuration default\n",
      "Reusing dataset chn_senti_corp (/Users/lee/.cache/huggingface/datasets/seamew___chn_senti_corp/default/0.0.0/1f242195a37831906957a11a2985a4329167e60657c07dc95ebe266c03fdfb85)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0\n",
      "1\n",
      "2\n",
      "3\n",
      "4\n",
      "0.86875\n"
     ]
    }
   ],
   "source": [
    "#测试\n",
    "def test():\n",
    "    model.eval()\n",
    "    correct = 0\n",
    "    total = 0\n",
    "\n",
    "    loader_test = torch.utils.data.DataLoader(dataset=Dataset('validation'),\n",
    "                                              batch_size=32,\n",
    "                                              collate_fn=collate_fn,\n",
    "                                              shuffle=True,\n",
    "                                              drop_last=True)\n",
    "\n",
    "    for i, (input_ids, attention_mask, token_type_ids,\n",
    "            labels) in enumerate(loader_test):\n",
    "\n",
    "        if i == 5:\n",
    "            break\n",
    "\n",
    "        print(i)\n",
    "\n",
    "        with torch.no_grad():\n",
    "            out = model(input_ids=input_ids,\n",
    "                        attention_mask=attention_mask,\n",
    "                        token_type_ids=token_type_ids)\n",
    "\n",
    "        out = out.argmax(dim=1)\n",
    "        correct += (out == labels).sum().item()\n",
    "        total += len(labels)\n",
    "\n",
    "    print(correct / total)\n",
    "\n",
    "\n",
    "test()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
